#include "VltImage.h"

#include "VltDevice.h"

namespace sce::vlt
{
	std::atomic<uint64_t> VltImageView::s_cookie = { 0ull };

	VltImage::VltImage(
		VltDevice*                device,
		const VltImageCreateInfo& createInfo,
		VltMemoryAllocator&       memAlloc,
		VkMemoryPropertyFlags     memFlags) :
		m_device(device),
		m_info(createInfo), m_memFlags(memFlags)
	{
		// Copy the compatible view formats to a persistent array
		m_viewFormats.resize(createInfo.viewFormatCount);
		for (uint32_t i = 0; i < createInfo.viewFormatCount; i++)
			m_viewFormats[i] = createInfo.viewFormats[i];
		m_info.viewFormats = m_viewFormats.data();

		// If defined, we should provide a format list, which
		// allows some drivers to enable image compression
		VkImageFormatListCreateInfo formatList;
		formatList.sType           = VK_STRUCTURE_TYPE_IMAGE_FORMAT_LIST_CREATE_INFO;
		formatList.pNext           = nullptr;
		formatList.viewFormatCount = createInfo.viewFormatCount;
		formatList.pViewFormats    = createInfo.viewFormats;

		VkImageCreateInfo info;
		info.sType                 = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
		info.pNext                 = &formatList;
		info.flags                 = createInfo.flags;
		info.imageType             = createInfo.type;
		info.format                = createInfo.format;
		info.extent                = createInfo.extent;
		info.mipLevels             = createInfo.mipLevels;
		info.arrayLayers           = createInfo.numLayers;
		info.samples               = createInfo.sampleCount;
		info.tiling                = createInfo.tiling;
		info.usage                 = createInfo.usage;
		info.sharingMode           = VK_SHARING_MODE_EXCLUSIVE;
		info.queueFamilyIndexCount = 0;
		info.pQueueFamilyIndices   = nullptr;
		info.initialLayout         = createInfo.initialLayout;

		if (vkCreateImage(m_device->handle(),
						  &info, nullptr, &m_image.image) != VK_SUCCESS)
		{
			Logger::exception(util::str::formatex(
				"DxvkImage: Failed to create image:",
				"\n  Type:            ", info.imageType,
				"\n  Format:          ", info.format,
				"\n  Extent:          ", "(", info.extent.width,
				",", info.extent.height,
				",", info.extent.depth, ")",
				"\n  Mip levels:      ", info.mipLevels,
				"\n  Array layers:    ", info.arrayLayers,
				"\n  Samples:         ", info.samples,
				"\n  Usage:           ", info.usage,
				"\n  Tiling:          ", info.tiling));
		}

		// Get memory requirements for the image. We may enforce strict
		// alignment on non-linear images in order not to violate the
		// bufferImageGranularity limit, which may be greater than the
		// required resource memory alignment on some GPUs.
		VkMemoryDedicatedRequirements dedicatedRequirements;
		dedicatedRequirements.sType                       = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS;
		dedicatedRequirements.pNext                       = VK_NULL_HANDLE;
		dedicatedRequirements.prefersDedicatedAllocation  = VK_FALSE;
		dedicatedRequirements.requiresDedicatedAllocation = VK_FALSE;

		VkMemoryRequirements2 memReq;
		memReq.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2;
		memReq.pNext = &dedicatedRequirements;

		VkImageMemoryRequirementsInfo2 memReqInfo;
		memReqInfo.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2;
		memReqInfo.image = m_image.image;
		memReqInfo.pNext = VK_NULL_HANDLE;

		VkMemoryDedicatedAllocateInfo dedMemoryAllocInfo;
		dedMemoryAllocInfo.sType  = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO;
		dedMemoryAllocInfo.pNext  = VK_NULL_HANDLE;
		dedMemoryAllocInfo.buffer = VK_NULL_HANDLE;
		dedMemoryAllocInfo.image  = m_image.image;

		vkGetImageMemoryRequirements2(
			m_device->handle(), &memReqInfo, &memReq);

		if (info.tiling != VK_IMAGE_TILING_LINEAR && !dedicatedRequirements.prefersDedicatedAllocation)
		{
			memReq.memoryRequirements.size      = util::align(memReq.memoryRequirements.size, memAlloc.bufferImageGranularity());
			memReq.memoryRequirements.alignment = util::align(memReq.memoryRequirements.alignment, memAlloc.bufferImageGranularity());
		}

		// Use high memory priority for GPU-writable resources
		bool isGpuWritable = (m_info.access & (VK_ACCESS_SHADER_WRITE_BIT |
											   VK_ACCESS_COLOR_ATTACHMENT_READ_BIT |
											   VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT |
											   VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_READ_BIT |
											   VK_ACCESS_DEPTH_STENCIL_ATTACHMENT_WRITE_BIT)) != 0;

		VltMemoryFlags hints(VltMemoryFlag::GpuReadable);

		if (isGpuWritable)
			hints.set(VltMemoryFlag::GpuWritable);

		// Ask driver whether we should be using a dedicated allocation
		m_image.memory = memAlloc.alloc(&memReq.memoryRequirements,
										dedicatedRequirements, dedMemoryAllocInfo, memFlags, hints);

		// Try to bind the allocated memory slice to the image
		if (vkBindImageMemory(m_device->handle(), m_image.image,
							  m_image.memory.memory(), m_image.memory.offset()) != VK_SUCCESS)
			Logger::exception("DxvkImage::DxvkImage: Failed to bind device memory");
	}

	VltImage::VltImage(
		VltDevice*                device,
		const VltImageCreateInfo& info,
		VkImage                   image) :
		m_device(device),
		m_info(info), m_image({ image })
	{
		m_viewFormats.resize(info.viewFormatCount);
		for (uint32_t i = 0; i < info.viewFormatCount; i++)
			m_viewFormats[i] = info.viewFormats[i];
		m_info.viewFormats = m_viewFormats.data();
	}

	VltImage::~VltImage()
	{
		// This is a bit of a hack to determine whether
		// the image is implementation-handled or not
		if (m_image.memory.memory() != VK_NULL_HANDLE)
			vkDestroyImage(m_device->handle(), m_image.image, nullptr);
	}

	VkSubresourceLayout VltImage::querySubresourceLayout(const VkImageSubresource& subresource) const
	{
		VkSubresourceLayout result;
		vkGetImageSubresourceLayout(
			m_device->handle(), m_image.image,
			&subresource, &result);
		return result;
	}

	VltImageView::VltImageView(
		VltDevice*                    device,
		const Rc<VltImage>&           image,
		const VltImageViewCreateInfo& info) :
		m_device(device),
		m_image(image), m_info(info), m_cookie(++s_cookie)
	{
		for (uint32_t i = 0; i < ViewCount; i++)
			m_views[i] = VK_NULL_HANDLE;

		switch (m_info.type)
		{
		case VK_IMAGE_VIEW_TYPE_1D:
		case VK_IMAGE_VIEW_TYPE_1D_ARRAY:
		{
			this->createView(VK_IMAGE_VIEW_TYPE_1D, 1);
			this->createView(VK_IMAGE_VIEW_TYPE_1D_ARRAY, m_info.numLayers);
		}
		break;

		case VK_IMAGE_VIEW_TYPE_2D:
		case VK_IMAGE_VIEW_TYPE_2D_ARRAY:
			this->createView(VK_IMAGE_VIEW_TYPE_2D, 1);
			/* fall through */

		case VK_IMAGE_VIEW_TYPE_CUBE:
		case VK_IMAGE_VIEW_TYPE_CUBE_ARRAY:
		{
			this->createView(VK_IMAGE_VIEW_TYPE_2D_ARRAY, m_info.numLayers);

			if (m_image->info().flags & VK_IMAGE_CREATE_CUBE_COMPATIBLE_BIT)
			{
				uint32_t cubeCount = m_info.numLayers / 6;

				if (cubeCount > 0)
				{
					this->createView(VK_IMAGE_VIEW_TYPE_CUBE, 6);
					this->createView(VK_IMAGE_VIEW_TYPE_CUBE_ARRAY, 6 * cubeCount);
				}
			}
		}
		break;

		case VK_IMAGE_VIEW_TYPE_3D:
		{
			this->createView(VK_IMAGE_VIEW_TYPE_3D, 1);

			if (m_image->info().flags & VK_IMAGE_CREATE_2D_ARRAY_COMPATIBLE_BIT && m_info.numLevels == 1)
			{
				this->createView(VK_IMAGE_VIEW_TYPE_2D, 1);
				this->createView(VK_IMAGE_VIEW_TYPE_2D_ARRAY, m_image->mipLevelExtent(m_info.minLevel).depth);
			}
		}
		break;

		default:
			Logger::exception(util::str::formatex("DxvkImageView: Invalid view type: ", m_info.type));
		}
	}

	VltImageView::~VltImageView()
	{
		for (uint32_t i = 0; i < ViewCount; i++)
			vkDestroyImageView(m_device->handle(), m_views[i], nullptr);
	}

	void VltImageView::createView(VkImageViewType type, uint32_t numLayers)
	{
		VkImageSubresourceRange subresourceRange;
		subresourceRange.aspectMask     = m_info.aspect;
		subresourceRange.baseMipLevel   = m_info.minLevel;
		subresourceRange.levelCount     = m_info.numLevels;
		subresourceRange.baseArrayLayer = m_info.minLayer;
		subresourceRange.layerCount     = numLayers;

		VkImageViewUsageCreateInfo viewUsage;
		viewUsage.sType = VK_STRUCTURE_TYPE_IMAGE_VIEW_USAGE_CREATE_INFO;
		viewUsage.pNext = nullptr;
		viewUsage.usage = m_info.usage;

		VkImageViewCreateInfo viewInfo;
		viewInfo.sType            = VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO;
		viewInfo.pNext            = &viewUsage;
		viewInfo.flags            = 0;
		viewInfo.image            = m_image->handle();
		viewInfo.viewType         = type;
		viewInfo.format           = m_info.format;
		viewInfo.components       = m_info.swizzle;
		viewInfo.subresourceRange = subresourceRange;

		if (m_info.usage == VK_IMAGE_USAGE_COLOR_ATTACHMENT_BIT)
		{
			viewInfo.components = {
				VK_COMPONENT_SWIZZLE_IDENTITY, VK_COMPONENT_SWIZZLE_IDENTITY,
				VK_COMPONENT_SWIZZLE_IDENTITY, VK_COMPONENT_SWIZZLE_IDENTITY
			};
		}

		if (vkCreateImageView(m_device->handle(),
							  &viewInfo, nullptr, &m_views[type]) != VK_SUCCESS)
		{
			Logger::exception(util::str::formatex(
				"DxvkImageView: Failed to create image view:"
				"\n  View type:       ",
				viewInfo.viewType,
				"\n  View format:     ", viewInfo.format,
				"\n  Subresources:    ",
				"\n    Aspect mask:   ", std::hex, viewInfo.subresourceRange.aspectMask,
				"\n    Mip levels:    ", viewInfo.subresourceRange.baseMipLevel, " - ",
				viewInfo.subresourceRange.levelCount,
				"\n    Array layers:  ", viewInfo.subresourceRange.baseArrayLayer, " - ",
				viewInfo.subresourceRange.layerCount,
				"\n  Image properties:",
				"\n    Type:          ", m_image->info().type,
				"\n    Format:        ", m_image->info().format,
				"\n    Extent:        ", "(", m_image->info().extent.width,
				",", m_image->info().extent.height,
				",", m_image->info().extent.depth, ")",
				"\n    Mip levels:    ", m_image->info().mipLevels,
				"\n    Array layers:  ", m_image->info().numLayers,
				"\n    Samples:       ", m_image->info().sampleCount,
				"\n    Usage:         ", std::hex, m_image->info().usage,
				"\n    Tiling:        ", m_image->info().tiling));
		}
	}

}  // namespace sce::vlt